{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 语言模型\n",
    "------\n",
    "## 1 简介\n",
    "- pytorch实现LSTM训练语言模型\n",
    "- 使用的数据集：\n",
    "    - bobsue.lm.train.txt：语言模型训练数据（LMTRAIN）\n",
    "    - bobsue.lm.dev.txt：语言模型验证数据（LMDEV）\n",
    "    - bobsue.lm.test.txt：语言模型测试数据（LMTEST）\n",
    "    - bobsue.prevsent.train.tsv：基于上文的语言模型训练数据（PREVSENTTRAIN）\n",
    "    - bobsue.prevsent.dev.tsv：基于上文的上文语言模型验证数据（PREVSENTDEV）\n",
    "    - bobsue.prevsent.test.tsv：基于上文的上文语言模型测试数据（PREVSENTTEST）\n",
    "    - bobsue.voc.txt：词汇表文件，每行是一个单词\n",
    "    - lm文件中的每一行都包含一个故事中的句子。prev文件中的每一行都包含一故事中的一个句子，tab，然后是故事中的下一个句子。注意：prevsent文件中每一行的第二个字段与相应的lm文件中的对应行相同。( 也就是说：cut -f 2 bobsue.prevsent.x.tsv与bobsue.lm.x.txt相同）完整的词汇表包含在文件bobsue.voc.txt中，每一行是一个单词。在这个任务中不会出现未知单词。\n",
    "- 评估\n",
    "    - 我们使用单词预测准确率作为主要评估指标而非困惑度(perplexity)。因为当你试图比较某些损失函数时，perplexity不太好用。\n",
    "\n",
    "-----\n",
    "## 2 具体内容\n",
    "### 用Log Loss训练LSTM模型\n",
    "- 实现一个基于LSTM的语言模型。具体为：\n",
    "    - 对每个当前的hidden state做一个线性变化和softmax处理，预测下一个单词。\n",
    "    - 使用Log Loss (Cross Entropy Loss)来训练该模型。使用**EVALLM**的步骤来评估模型。\n",
    "    - 汇报模型训练结果和代码。你的单词预测准确率应该能够达到30%以上。\n",
    "- 要求\n",
    "    - 至少训练10个epoch\n",
    "    - 可以使用不同的模型参数。建议使用一层LSTM，200 hidden dimension作为词向量和LSTM hidden state的大小。\n",
    "    - 模型参数的初始值可以随机设定\n",
    "    - 输入和输出层的word embedding参数可以不一样，当然你也可以尝试把他们设置成一样。\n",
    "    - 使用Adam或者SGD等optimizer来优化模型。\n",
    "    - 在提交报告的时候请尽可能详细描述你的所有模型参数。\n",
    "----\n",
    "### 错误分析\n",
    "- 请在你的代码中添加一项功能，可以展示出你模型预测错误的单词，将标准答案单词和模型预测的单词分别打印出来。\n",
    "- 请写下你的模型最常见的35个预测错误(正确答案是a，模型预测了b)。\n",
    "- 通过观察这些常见的错误，将错误分类。你不需要将每个错误都分类，不过建议同学们花点时间观察自己模型的错误，看看他们是否有一定的相关性。大家可以尝试从以下角度思考错误类型：\n",
    "    - 为什么你的模型会预测出这个单词？\n",
    "    - 模型怎么样才能做得更好？这个模型犯的错误是很接近正确答案的吗？如果是的话，这个错误答案与正确答案有何相似之处？\n",
    "    - 把这35个预测错误归类成你定义的错误类别。讨论一下你的模型在哪些方面做得比较好，哪些方面做的不好。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "from collections import Counter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据文件\n",
    "word_file = './data/bobsue.voc.txt'\n",
    "train_file = './data/bobsue.lm.train.txt'\n",
    "test_file = './data/bobsue.lm.test.txt'\n",
    "dev_file = './data/bobsue.lm.dev.txt'\n",
    "\n",
    "BATCH_SIZE = 32       # 批次大小\n",
    "EMBEDDING_DIM = 200   # 词向量维度\n",
    "HIDDEN_DIM = 200      # 隐含层\n",
    "GRAD_CLIP = 5.        # 梯度截断值\n",
    "EPOCHS = 20 \n",
    "LEARN_RATE = 0.01     # 初始学习率\n",
    "\n",
    "BEST_VALID_ACC = 0.     # 初始验证集上的损失值，设为最大\n",
    "MODEL_PATH = \"lm-best-dim{}.pth\"   # 模型名称\n",
    "USE_CUDA = torch.cuda.is_available()    # 是否使用GPU\n",
    "NUM_CUDA = torch.cuda.device_count()    # GPU数量"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 数据预处理\n",
    "### 1.1 读取数据文件，构建词汇集、word2idx、idx2word。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_word_set(filename):\n",
    "    with open(filename, \"r\", encoding=\"utf-8\") as f:\n",
    "        word_set = set([line.strip() for line in f])\n",
    "    return word_set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "word_set = load_word_set(word_file)\n",
    "word2idx = {w:i for i, w in enumerate(word_set, 1)}\n",
    "idx2word = {i:w for i, w in enumerate(word_set, 1)}\n",
    "\n",
    "# 将pad的索引设置为0并添加到词表\n",
    "PAD_IDX = 0\n",
    "word2idx[\"<pad>\"] = PAD_IDX\n",
    "idx2word[PAD_IDX] = \"<pad>\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 训练、验证、测试数据准备\n",
    "- 将数据处理成模型可以接收的格式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_corpus(filename):\n",
    "    \"\"\"读取数据集，返回句子列表\"\"\"\n",
    "    with open(filename, \"r\", encoding=\"utf-8\") as f:\n",
    "        sentences = [line.strip() for line in f]\n",
    "    return sentences\n",
    "\n",
    "def sentences2words(sentences):\n",
    "    \"\"\"将句子列表转换成单词列表\"\"\"\n",
    "    return [w for s in sentences for w in s.split()]\n",
    "\n",
    "def max_sentence_num(sentences):\n",
    "    \"\"\"返回最长句子单词数量\"\"\"\n",
    "    return max([len(s.split()) for s in sentences ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 各数据集的句子列表\n",
    "train_sentences = load_corpus(train_file)\n",
    "dev_sentences = load_corpus(dev_file)\n",
    "test_sentences = load_corpus(test_file)\n",
    "\n",
    "# 各数据集的单词列表\n",
    "train_words = sentences2words(train_sentences)\n",
    "dev_words = sentences2words(dev_sentences)\n",
    "test_words = sentences2words(test_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集句子数: 6036, 单词数: 71367.\n",
      "验证集句子数: 750, 单词数: 8707.\n",
      "测试集句子数: 750, 单词数: 8809.\n",
      "--------------------------------------------------\n",
      "训练集最长句子单词个数： 21\n",
      "验证集最长句子单词个数： 20\n",
      "测试集最长句子单词个数： 21\n"
     ]
    }
   ],
   "source": [
    "# 查看处理后训练集、验证集、测试集的基本情况\n",
    "s = \"{}句子数: {}, 单词数: {}.\"\n",
    "print(s.format(\"训练集\", len(train_sentences), len(train_words)))\n",
    "print(s.format(\"验证集\", len(dev_sentences), len(dev_words)))\n",
    "print(s.format(\"测试集\", len(test_sentences), len(test_words)))\n",
    "\n",
    "print(\"-\"*50)\n",
    "\n",
    "# 这里需要知道各数据集上最长句子的单词📚，以便后面构造单词索引向量的时候设置一个恰当的维度\n",
    "print(\"训练集最长句子单词个数：\", max_sentence_num(train_sentences))\n",
    "print(\"验证集最长句子单词个数：\", max_sentence_num(dev_sentences))\n",
    "print(\"测试集最长句子单词个数：\", max_sentence_num(test_sentences))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_x_y(corpus, word2idx, seq_len=21):\n",
    "    \"\"\"\n",
    "    构造输入模型的特征以及标签。\n",
    "    输入：\n",
    "        corpus： 列表，每个元素是一个句子。\n",
    "        word2idx： 字典，key是单词，value是单词的索引。\n",
    "        seq_len：int, 句子切分后的单词序列的长度。\n",
    "    返回：\n",
    "        sentences：二维列表，每一行是一个句子切分后单词的索引列表（不包括句子的最后一个单词）。输入模型的x。\n",
    "        labels：二维列表，每一行是一个句子切分后单词的索引列表（不包括句子的第一个单词）。y。\n",
    "    \"\"\"\n",
    "    sentences = []\n",
    "    labels = []\n",
    "    for sentence in corpus:\n",
    "        words = sentence.split()\n",
    "        sentence_vec = [0]*seq_len\n",
    "        for i, w in enumerate(words[:-1]):\n",
    "            sentence_vec[i] = word2idx[w]\n",
    "        sentences.append(sentence_vec)\n",
    "        \n",
    "        label_vec = [0] * seq_len\n",
    "        for i, w in enumerate(words[1:]):\n",
    "            label_vec[i] = word2idx[w]\n",
    "        labels.append(label_vec)\n",
    "    return sentences, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data, train_label = build_x_y(train_sentences, word2idx)\n",
    "dev_data, dev_label = build_x_y(dev_sentences, word2idx)\n",
    "test_data, test_label = build_x_y(test_sentences, word2idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1079, 1029, 717, 122, 1259, 413, 1224, 944, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "      [1029, 717, 122, 1259, 413, 1224, 944, 1308, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n"
     ]
    }
   ],
   "source": [
    "# 查看处理后的训练集第一个样本及标签\n",
    "print(train_data[1])\n",
    "print(\" \"*5, train_label[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s> The girl broke up with Bob . <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n",
      "    The girl broke up with Bob . </s> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>\n"
     ]
    }
   ],
   "source": [
    "idx = 1\n",
    "print(\" \".join([idx2word[i] for i in train_data[idx]]))\n",
    "\n",
    "print(\" \"*3, \" \".join([idx2word[i] for i in train_label[idx]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构造批次数据\n",
    "def build_batch_data(data, label, batch_size=32):\n",
    "    \"\"\"构建 batch tensor，返回 batch 列表，每个batch为二元组包含data和label\"\"\"\n",
    "    batch_data = []\n",
    "    data_tensor = torch.tensor(data, dtype=torch.long)\n",
    "    label_tensor = torch.tensor(label, dtype=torch.long)\n",
    "    n, dim = data_tensor.size()\n",
    "    for start in range(0, n, batch_size):\n",
    "        end = start + batch_size\n",
    "        if end > n:\n",
    "            dbatch = data_tensor[start: ]\n",
    "            lbatch = label_tensor[start: ]\n",
    "            print(\"最后一个batch size:\", dbatch.size())\n",
    "            break\n",
    "        else:\n",
    "            dbatch = data_tensor[start: end]\n",
    "            lbatch = label_tensor[start: end]\n",
    "        batch_data.append((dbatch, lbatch))\n",
    "    return batch_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "最后一个batch size: torch.Size([20, 21])\n",
      "最后一个batch size: torch.Size([14, 21])\n",
      "最后一个batch size: torch.Size([14, 21])\n"
     ]
    }
   ],
   "source": [
    "train_batch = build_batch_data(train_data, train_label, batch_size=BATCH_SIZE)\n",
    "dev_batch = build_batch_data(dev_data, dev_label, batch_size=BATCH_SIZE)\n",
    "test_batch = build_batch_data(test_data, test_label, batch_size=BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "188 23 23\n"
     ]
    }
   ],
   "source": [
    "# 查看各数据集有多少个batch\n",
    "print(len(train_batch), len(dev_batch), len(test_batch))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义模型\n",
    "class MyLSTM(nn.Module):\n",
    "    def __init__(self, vocab_size, embedding_dim, hidden_dim):\n",
    "        super(MyLSTM, self).__init__()\n",
    "        self.vocab_size = vocab_size\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.word_embeddings = nn.Embedding(self.vocab_size, embedding_dim)\n",
    "        # batch_first=True 意味着输入是(batch, seq, feature)\n",
    "        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)\n",
    "        self.hidden2word = nn.Linear(hidden_dim, self.vocab_size)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        embeds = self.word_embeddings(x)\n",
    "        lstm_out, (h_n, c_n) = self.lstm(embeds)\n",
    "        target_space = self.hidden2word(lstm_out.contiguous().view(-1, self.hidden_dim))\n",
    "        mask = (x != PAD_IDX).view(-1)\n",
    "        mask_target = target_space[mask]\n",
    "        \n",
    "        target_scores = F.log_softmax(mask_target, dim=1)\n",
    "        return target_scores\n",
    "\n",
    "    \n",
    "# 计算准确率\n",
    "def acc_score(pred_score, y):\n",
    "    # 返回最大的概率的索引\n",
    "    y_pred = pred_score.argmax(dim=1)\n",
    "    # print(y.view(-1))\n",
    "    acc_count = torch.eq(y_pred, y.view(-1))\n",
    "    score = acc_count.sum().item() / acc_count.size()[0]\n",
    "    return score\n",
    "\n",
    "\n",
    "# 训练函数\n",
    "def train(model, device, iterator, optimizer, criterion, grad_clip):\n",
    "    epoch_loss = 0  # 积累变量\n",
    "    epoch_acc = 0   # 积累变量\n",
    "    model.train()   # 该函数表示PHASE=Train\n",
    "    \n",
    "    for x, y in iterator:  # 拿每一个minibatch\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        mask = y != PAD_IDX\n",
    "        pure_y = y[mask]\n",
    "        \n",
    "        fx = model(x)                 # 进行forward\n",
    "        loss = criterion(fx, pure_y)  # 计算loss\n",
    "        acc = acc_score(fx, pure_y)   # 计算准确率\n",
    "        loss.backward()               # 进行BP\n",
    "        \n",
    "        # 梯度裁剪\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
    "        optimizer.step()  # 更新参数\n",
    "        \n",
    "        epoch_loss += loss\n",
    "        epoch_acc += acc\n",
    "        \n",
    "    return epoch_loss/len(iterator),epoch_acc/len(iterator)\n",
    "\n",
    "\n",
    "# 验证函数，验证集和测试集用，不更新梯度\n",
    "def evaluate(model, device, iterator, criterion):\n",
    "    model.eval()  # 不更新参数，预测模式\n",
    "    epoch_loss=0  # 积累变量\n",
    "    epoch_acc=0   # 积累变量\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        for x, y in iterator:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            mask = y != PAD_IDX\n",
    "            pure_y = y[mask]\n",
    "            \n",
    "            fx = model(x)\n",
    "            loss = criterion(fx, pure_y)\n",
    "            acc = acc_score(fx, pure_y)\n",
    "            epoch_loss += loss\n",
    "            epoch_acc += acc\n",
    "    return epoch_loss/len(iterator), epoch_acc/len(iterator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3 模型训练与评价"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0, 1, 2, 3]\n"
     ]
    }
   ],
   "source": [
    "VOCAB_SIZE = len(word2idx)     # 词汇表长度\n",
    "\n",
    "model = MyLSTM(VOCAB_SIZE, EMBEDDING_DIM, HIDDEN_DIM)\n",
    "\n",
    "DEVICE = torch.device(\"cuda\" if USE_CUDA else 'cpu')\n",
    "model = model.to(DEVICE)\n",
    "\n",
    "# 使用多块GPU\n",
    "if NUM_CUDA > 1:\n",
    "    device_ids = list(range(NUM_CUDA))\n",
    "    print(device_ids)\n",
    "    model = nn.DataParallel(model, device_ids=device_ids)\n",
    "    # model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- 保存最优模型的逻辑是，每一个epoch之后再对比验证集损失值，验证集损失降低才认为模型更优。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/lib/python3.6/site-packages/torch/nn/modules/rnn.py:179: RuntimeWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().\n",
      "  self.dropout, self.training, self.bidirectional, self.batch_first)\n",
      "/root/anaconda3/lib/python3.6/site-packages/torch/serialization.py:251: UserWarning: Couldn't retrieve source code for container of type MyLSTM. It won't be checked for correctness upon loading.\n",
      "  \"type \" + obj.__name__ + \". It won't be checked \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "save best model: lm-best-dim200.pth\n",
      "Epoch:1|Train Loss:3.9867|Train Acc:0.284045|Val Loss:3.54076|Val Acc:0.322146\n",
      "Epoch:1|Train Loss:3.9867|Train Acc:0.284045|Val Loss:3.54076|Val Acc:0.322146\n",
      "save best model: lm-best-dim200.pth\n",
      "Epoch:2|Train Loss:3.28995|Train Acc:0.32511|Val Loss:3.49253|Val Acc:0.325996\n",
      "Epoch:2|Train Loss:3.28995|Train Acc:0.32511|Val Loss:3.49253|Val Acc:0.325996\n",
      "Epoch:3|Train Loss:2.96528|Train Acc:0.348784|Val Loss:3.55743|Val Acc:0.325569\n",
      "Epoch:4|Train Loss:2.71137|Train Acc:0.374167|Val Loss:3.66872|Val Acc:0.316473\n",
      "Current lr: 0.01\n",
      "Epoch:5|Train Loss:2.50227|Train Acc:0.404593|Val Loss:3.74011|Val Acc:0.320505\n",
      "Epoch:6|Train Loss:2.32764|Train Acc:0.431737|Val Loss:3.86074|Val Acc:0.314687\n",
      "Epoch:7|Train Loss:2.17701|Train Acc:0.460746|Val Loss:3.96021|Val Acc:0.309366\n",
      "Current lr: 0.005\n",
      "Epoch:8|Train Loss:2.05596|Train Acc:0.485682|Val Loss:4.06588|Val Acc:0.305292\n",
      "Early stop!\n"
     ]
    }
   ],
   "source": [
    "criterion = nn.NLLLoss()                                             # 指定损失函数\n",
    "optimizer = optim.Adam(model.parameters(), lr=LEARN_RATE)            # 指定优化器\n",
    "scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.5)   # 学习率缩减\n",
    "\n",
    "model_name = MODEL_PATH.format(EMBEDDING_DIM)\n",
    "LOG_INFO = 'Epoch:{}|Train Loss:{:.6}|Train Acc:{:.6}|Val Loss:{:.6}|Val Acc:{:.6}'\n",
    "\n",
    "SCHED_NUM = 0\n",
    "for epoch in range(1, EPOCHS+1):\n",
    "    train_loss, train_acc = train(model, DEVICE, train_batch, optimizer, criterion, GRAD_CLIP)\n",
    "    valid_loss, valid_acc = evaluate(model, DEVICE, dev_batch, criterion)\n",
    "    if valid_acc > BEST_VALID_ACC: # 如果是最好的模型就保存到文件夹\n",
    "        BEST_VALID_ACC = valid_acc\n",
    "        torch.save(model, model_name)\n",
    "        print(\"save best model:\", model_name)\n",
    "        print(LOG_INFO.format(epoch, train_loss, train_acc, valid_loss, valid_acc))\n",
    "        SCHED_NUM = 0\n",
    "    else:\n",
    "        SCHED_NUM += 1\n",
    "        if SCHED_NUM % 3 == 0:\n",
    "            scheduler.step()\n",
    "            print(\"Current lr:\", optimizer.param_groups[0]['lr'])\n",
    "        if SCHED_NUM == 7:\n",
    "            print(\"Early stop!\")\n",
    "            break\n",
    "    print(LOG_INFO.format(epoch, train_loss, train_acc, valid_loss, valid_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test Loss: 3.5174615383148193 | Test Acc: 0.3169802856441505 |\n"
     ]
    }
   ],
   "source": [
    "model = torch.load(model_name)\n",
    "test_loss, test_acc = evaluate(model, DEVICE, test_batch, criterion)\n",
    "print('Test Loss: {0} | Test Acc: {1} |'.format(test_loss, test_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 打印错误单词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 答应预测错误的单词\n",
    "def print_pred_error_words(model,device,data_batch):\n",
    "    model.eval()\n",
    "    error_words = []\n",
    "    with torch.no_grad():\n",
    "        for x, y in data_batch:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device)\n",
    "            \n",
    "            mask = (y!=PAD_IDX)\n",
    "            fx = model(x)\n",
    "            \n",
    "            pred_idx = fx.argmax(dim=1)\n",
    "            ground_truth_idx = y[mask]\n",
    "            for p, g in zip(pred_idx.tolist(), ground_truth_idx.tolist()):\n",
    "                if p != g:\n",
    "                    error_words.append(\" | \".join([idx2word[g], idx2word[p]]))\n",
    "    return error_words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = torch.load(MODEL_PATH.format(EMBEDDING_DIM))\n",
    "error_words = print_pred_error_words(model, DEVICE, test_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "真实值 | 预测值 | 预测错误次数\n",
      "('Bob | He', 137)\n",
      "('She | He', 109)\n",
      "('Sue | He', 89)\n",
      "('to | .', 43)\n",
      "('and | .', 41)\n",
      "('had | was', 40)\n",
      "('his | the', 37)\n",
      "('decided | was', 37)\n",
      "('for | .', 31)\n",
      "('her | the', 30)\n",
      "(', | .', 28)\n",
      "('His | He', 26)\n",
      "('. | to', 25)\n",
      "('in | .', 25)\n",
      "('One | He', 25)\n",
      "('a | the', 24)\n",
      "('. | the', 23)\n",
      "('and | to', 23)\n",
      "('went | was', 21)\n",
      "('But | He', 21)\n",
      "('Her | He', 21)\n",
      "('The | He', 21)\n",
      "('got | was', 19)\n",
      "('When | He', 19)\n",
      "('They | He', 19)\n",
      "('it | the', 18)\n",
      "('a | to', 17)\n",
      "('! | .', 17)\n",
      "('wanted | was', 17)\n",
      "('she | he', 15)\n",
      "('he | Bob', 15)\n",
      "('her | a', 15)\n",
      "('the | her', 15)\n",
      "('the | .', 15)\n",
      "('he | to', 15)\n"
     ]
    }
   ],
   "source": [
    "words_counter = Counter(error_words)\n",
    "TopN = 35\n",
    "topn_words = words_counter.most_common(TopN)\n",
    "print(\"真实值 | 预测值 | 预测错误次数\")\n",
    "for w in topn_words:\n",
    "    print(w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
